import torch
from torch import nn
from torchsummary import summary


# Residual block definition
class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1conv=False, strides=1):
        super(Residual, self).__init__()
        # 激活函数
        self.ReLU = nn.ReLU()
        # 卷积层
        # 大小为3，填充为1
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels, kernel_size=3, padding=1,
                               stride=strides)
        self.conv2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=3, padding=1)
        # BN层  根据输出通道数进行批归一化
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        # 是否使用1*1卷积
        if use_1conv:
            self.conv3 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None

    # 前向传播层
    def forward(self, x):
        y = self.ReLU(self.bn1(self.conv1(x)))  # 第一层卷积、BN规范化、激活函数
        y = self.bn2(self.conv2(y))  # 第二层卷积、BN规范化
        if self.conv3:  # 存在则1*1卷积
            x = self.conv3(x)
        y = self.ReLU(y + x)  # 跳转连接
        return y


class ResNet18(nn.Module):
    """ResNet-18 for 3-channel input with a 5-way classification head.

    Args:
        Residual: residual-block class, instantiated as
            ``Residual(in_channels, out_channels, use_1conv=..., strides=...)``.
            (Name kept for backward compatibility even though it shadows the
            module-level ``Residual`` class.)
    """

    def __init__(self, Residual):
        super(ResNet18, self).__init__()
        # Stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max pool.
        # FIX: BatchNorm now runs before ReLU (Conv-BN-ReLU), the canonical
        # ResNet ordering and consistent with the Residual blocks below;
        # the previous version activated before normalizing.
        self.b1 = nn.Sequential(
            # Conv: in-channels, out-channels, kernel size, stride, padding
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Stage 1: 64 channels, no downsampling (the stem already reduced 4x),
        # so identity skips need no 1x1 projection.
        self.b2 = nn.Sequential(Residual(64, 64, use_1conv=False, strides=1),
                                Residual(64, 64, use_1conv=False, strides=1))

        # Stages 2-4: each first block doubles channels and halves the spatial
        # size (strides=2), so its skip path needs a 1x1 projection.
        self.b3 = nn.Sequential(Residual(64, 128, use_1conv=True, strides=2),
                                Residual(128, 128, use_1conv=False, strides=1))

        self.b4 = nn.Sequential(Residual(128, 256, use_1conv=True, strides=2),
                                Residual(256, 256, use_1conv=False, strides=1))

        self.b5 = nn.Sequential(Residual(256, 512, use_1conv=True, strides=2),
                                Residual(512, 512, use_1conv=False, strides=1))

        # Head: global average pool -> flatten -> linear classifier (5 classes).
        self.b6 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                nn.Flatten(),
                                nn.Linear(512, 5))

    def forward(self, x):
        """Run the stem, the four residual stages, and the classifier head."""
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        x = self.b4(x)
        x = self.b5(x)
        x = self.b6(x)
        return x


if __name__ == "__main__":
    # Prefer GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet18(Residual).to(device)
    # FIX: the stem conv declares in_channels=3, so the summary input must be
    # 3-channel; the previous (1, 224, 224) raised a channel-mismatch error.
    print(summary(model, (3, 224, 224)))
